{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Capsule\n",
    "\n",
    "核心思想是： 每个capsule代表一个特征。\n",
    "\n",
    "具体的解释，请看 [揭开迷雾，来一顿美味的Capsule盛宴 By 苏剑林](https://kexue.fm/archives/4819)\n",
    "\n",
    "需要强调的是，这里使用的 0.3.0 版本"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 导入各种包\n",
    "import torch\n",
    "import torchvision\n",
    "from torch.autograd import Variable\n",
    "import torch.nn.functional as F\n",
    "import torch.nn as nn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "因为笔者的愚钝，理解 capsule 确实花费了不少时间，于是把理解的过程一步步写下来，如果嫌弃麻烦的话，可以直接翻拉倒最下面查看完整的栗子"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 第一步： 我们不设置 batch 每次只输入一个样本进入 capsule layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CapsuleLayer_v1(nn.Module):\n",
    "    def __init__(self, in_cap_num, in_dim, out_cap_num, out_dim, routings):\n",
    "        super(CapsuleLayer_v1, self).__init__()\n",
    "        self.out_cap_num = out_cap_num  # 下一层 capsule 的个数\n",
    "        self.in_cap_num = in_cap_num  # 输入的 capsule 的个数\n",
    "        self.routings = routings\n",
    "        self.in_dim = in_dim  # 输入　capsule 的维度\n",
    "        self.out_dim = out_dim  # 输出　capsule 的维度\n",
    "        \n",
    "        # 变换矩阵\n",
    "        self.W = nn.Parameter(torch.randn(in_cap_num, out_cap_num, out_dim, in_dim)) # \n",
    "        \n",
    "    def forward(self, u_vecs):\n",
    "        \"\"\"\n",
    "            考虑简单情况，每次都是一个样本，也就是说，我们的输入 u_vecs 是 (in_capsule_num, capsule_dim)\n",
    "        \"\"\"        \n",
    "        # 完成变换矩阵\n",
    "        \n",
    "        b = Variable(torch.zeros(self.out_cap_num, self.in_cap_num))\n",
    "        u_hat = Variable(torch.zeros((self.out_cap_num, self.in_cap_num, self.out_dim)))\n",
    "        \n",
    "        #　为了方便理解，才这样写的，而且这里是没有　batach　存在的情况\n",
    "        for j in range(self.out_cap_num):\n",
    "            for i in range(self.in_cap_num):\n",
    "                u_hat[j, i] = torch.mm(self.W[i, j], u_vecs[i].view(-1,1))\n",
    "        # dynamic routing\n",
    "        for i in range(self.routings):\n",
    "            c = F.softmax(b, dim=1) # out_cap_num*input_capsule_num (表示概率)\n",
    "            s = torch.matmul(c.unsqueeze_(1), u_hat).squeeze_(1) # out_cap_num*out_dim\n",
    "            v = self.squash(s)\n",
    "            b = b + torch.matmul(u_hat, v.unsqueeze(2)).squeeze_(2)\n",
    "        return v\n",
    "\n",
    "    # 定义 squash 函数\n",
    "    def squash(self, x, p=2, dim=1, keepdim=True):\n",
    "        \"\"\"\n",
    "            params: x (num*feature), p: 几范数, dim: 对哪个维度求范数, keepdim: 保持维度一致\n",
    "            return: squash_x (b*m)\n",
    "        \"\"\"\n",
    "        squash_norm = torch.norm(x, p, dim, keepdim)\n",
    "        scale = torch.sqrt(squash_norm) / (1 + squash_norm)\n",
    "        return scale * x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "测试用例1：\n",
    "\n",
    "————————————————————————————————————————————————————————\n",
    "\n",
    "```\n",
    "cap_v1 = CapsuleLayer_v1(5, 128, 4, 16, 3)\n",
    "\n",
    "u_vecs = Variable(torch.randn((5, 128)))\n",
    "\n",
    "得到的 下一层的 capsule 的 size 应该是 (4, 16)\n",
    "```\n",
    "\n",
    "————————————————————————————————————————————————————————"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([4, 16])\n"
     ]
    }
   ],
   "source": [
    "cap_v1 = CapsuleLayer_v1(5, 128, 4, 16, 3)\n",
    "\n",
    "u_vecs_1 = Variable(torch.randn((5, 128)))\n",
    "\n",
    "v1 = cap_v1(u_vecs_1)\n",
    "\n",
    "print(v1.size())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "## 第二步： 我们设置 batch，每次都是输入 batch 个样本进入 capsule layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CapsuleLayer_v2(nn.Module):\n",
    "    def __init__(self, in_cap_num, in_dim, out_cap_num, out_dim, routings, batch):\n",
    "        super(CapsuleLayer_v2, self).__init__()\n",
    "        self.out_cap_num = out_cap_num  # 下一层 capsule 的个数\n",
    "        self.in_cap_num = in_cap_num  # 输入的 capsule 的个数\n",
    "        self.routings = routings\n",
    "        self.in_dim = in_dim  # 输入　capsule 的维度\n",
    "        self.out_dim = out_dim  # 输出　capsule 的维度\n",
    "        self.batch = batch # batch size\n",
    "        # 变换矩阵\n",
    "        self.W = nn.Parameter(torch.randn(batch, in_cap_num, out_cap_num, out_dim, in_dim)) # \n",
    "        \n",
    "    def forward(self, u_vecs):\n",
    "        \"\"\"\n",
    "            此时，每次都是 batch 个样本，也就是说，我们的输入 u_vecs 是 (batch, in_capsule_num, capsule_dim)\n",
    "        \"\"\"        \n",
    "        # 完成变换矩阵\n",
    "        \n",
    "        b = Variable(torch.zeros(self.batch, self.out_cap_num, self.in_cap_num))\n",
    "        u_hat = Variable(torch.zeros((self.batch, self.out_cap_num, self.in_cap_num, self.out_dim)))\n",
    "        \n",
    "        #　为了方便理解，我们在这里使用三个循环... 讲道理，这样写有很大的问题，但是为了便于理解，我们先这样写。\n",
    "        for k in range(self.batch):\n",
    "            for j in range(self.out_cap_num):\n",
    "                for i in range(self.in_cap_num):                    \n",
    "                    # u_hat[k, j, i] = [out_dim. in_dim] * [in_dim, 1]\n",
    "                    u_hat[k, j, i] = torch.mm(self.W[k, i, j], u_vecs[k, i].view(-1,1))\n",
    "        # dynamic routing\n",
    "        for i in range(self.routings):\n",
    "            c = F.softmax(b, dim=2) # batch*out_cap_num*input_capsule_num (表示概率)\n",
    "            s = torch.matmul(c.unsqueeze_(2), u_hat).squeeze_(2) # batch*out_cap_num*out_dim\n",
    "            v = self.squash(s, dim=2) # batch*out_cap_num*out_dim\n",
    "            b = b + torch.matmul(u_hat, v.unsqueeze(3)).squeeze_(3)\n",
    "        return v\n",
    "\n",
    "    # 定义 squash 函数\n",
    "    def squash(self, x, p=2, dim=1, keepdim=True):\n",
    "        squash_norm = torch.norm(x, p, dim, keepdim)\n",
    "        scale = torch.sqrt(squash_norm) / (1 + squash_norm)\n",
    "        return scale * x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "测试用例2:\n",
    "\n",
    "————————————————————————————————————————————————————————\n",
    "\n",
    "```\n",
    "cap_v2 = CapsuleLayer_v2(5, 128, 4, 16, 3, 64)\n",
    "\n",
    "u_vecs_2 = Variable(torch.randn((64, 5, 128)))\n",
    "\n",
    "得到的 下一层的 capsule 的 size 应该是 (64, 4, 16)\n",
    "```\n",
    "\n",
    "————————————————————————————————————————————————————————"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([64, 4, 16])\n"
     ]
    }
   ],
   "source": [
    "cap_v2 = CapsuleLayer_v2(5, 128, 4, 16, 3, 64)\n",
    "\n",
    "u_vecs_2 = Variable(torch.randn((64, 5, 128)))\n",
    "\n",
    "v2 = cap_v2(u_vecs_2)\n",
    "\n",
    "print(v2.size())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "## 第三步： 我们尝试不去使用 for 循环\n",
    "\n",
    "从上面两个过程，我们可以看出，计算是上的难点在于 我们要进行三个 for 循环，这可以说是相当耗时间的，所以，我们需要想办法解决掉 for 循环"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
